Large Batch Experience Replay

1 Overview

Large Batch Experience Replay (LaBER) was proposed by T. Lahire et al. [1].

The authors theoretically derived the optimal sampling probability \( p_i ^{\ast} \), which minimizes the variance of the stochastic gradient estimate:

\[ p_i^{\ast} \propto \| \nabla _{\theta} L(Q_{\theta}(s_i, a_i), y_i) \| \text{,}\]

where \(L(\cdot,\, \cdot)\) is a loss function and \(y_i\) is the target for the \(i\)-th transition.

Computing this norm requires a full backpropagation for every sample, which is costly, so the authors proposed the surrogate priority \(\hat{p}_i \propto \| \partial L(q_i, y_i) / \partial q_i \| \), where \(q_i = Q_{\theta}(s_i, a_i)\).

By the chain rule, \( \| \nabla _{\theta} L(Q_{\theta}(s_i, a_i), y_i) \| = \| \partial L(q_i, y_i) / \partial q_i \cdot \nabla _{\theta} Q_{\theta}(s_i, a_i) \| \), so the surrogate priority is a good approximation when \( \| \nabla _{\theta} Q_{\theta}(s_i, a_i) \| \) is almost constant across samples.
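The chain-rule argument can be checked numerically. The sketch below is a toy illustration, not part of cpprb: it assumes a linear Q-function \(Q_{\theta}(s, a) = \theta^{\top} \phi(s, a)\) with hypothetical random features \(\phi\) and the loss \(L = \tfrac{1}{2}(q - y)^2\). For this model \(\nabla_{\theta} Q = \phi\), so the full per-sample gradient norm equals the surrogate times \(\|\phi\|\), and both priorities give the same sampling distribution once all \(\|\phi\|\) are equal.

```python
import numpy as np

rng = np.random.default_rng(0)

n_samples, n_features = 8, 5
theta = rng.normal(size=n_features)              # parameters of a toy linear Q-function
phi = rng.normal(size=(n_samples, n_features))   # hypothetical features phi(s_i, a_i)
y = rng.normal(size=n_samples)                   # hypothetical targets y_i

q = phi @ theta     # q_i = Q_theta(s_i, a_i)
dL_dq = q - y       # derivative of 0.5 * (q - y)**2 with respect to q

# Full per-sample gradient: dL/dtheta = dL/dq * grad_theta Q = dL/dq * phi_i
full_norm = np.linalg.norm(dL_dq[:, None] * phi, axis=1)  # optimal (unnormalized) priority
surrogate = np.abs(dL_dq)                                 # surrogate priority
feature_norm = np.linalg.norm(phi, axis=1)                # ||grad_theta Q|| per sample

# The two priorities differ exactly by the factor ||grad_theta Q||
assert np.allclose(full_norm, surrogate * feature_norm)

# With equal-norm features, both priorities yield the same sampling distribution
phi_unit = phi / feature_norm[:, None]
q_unit = phi_unit @ theta
full_unit = np.linalg.norm((q_unit - y)[:, None] * phi_unit, axis=1)
p_full = full_unit / full_unit.sum()
p_surrogate = np.abs(q_unit - y) / np.abs(q_unit - y).sum()
print(np.allclose(p_full, p_surrogate))
```

In a real network \(\|\nabla_{\theta} Q\|\) varies between samples, which is exactly where the surrogate becomes an approximation.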

Moreover, when the loss function is the squared (L2) error, the surrogate priority is proportional to the absolute TD error.
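As a worked case, take the squared loss with the conventional factor \(\tfrac{1}{2}\), \(L(q_i, y_i) = \tfrac{1}{2}(q_i - y_i)^2\), and write the TD error as \(\delta_i = y_i - q_i\). Constant factors rescale every priority equally, so they do not change the sampling distribution:

\[ \frac{\partial L(q_i, y_i)}{\partial q_i} = q_i - y_i = -\delta_i \quad\Longrightarrow\quad \hat{p}_i \propto |\delta_i| \text{.}\]

For comparison, with the Huber loss (threshold 1) the same derivative is \(-\operatorname{clip}(\delta_i, -1, 1)\), so the exact surrogate would be \(\hat{p}_i \propto \min(|\delta_i|, 1)\), and \(|\delta_i|\) is only an approximation there.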

Using the TD error as priority works reasonably well, but one of the biggest problems of PER is that the stored priorities are always outdated, because they were computed with older network parameters. However, re-computing the priorities of all transitions in the buffer at every sampling step is too expensive.

LaBER first samples an \(m\)-times larger batch uniformly from the buffer, then computes surrogate priorities for only these samples with the current network, and finally sub-samples the final batch according to those priorities.
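The three steps can be sketched in plain NumPy. This is an illustration under stated assumptions, not cpprb code: `two_stage_sample` is a hypothetical helper, the buffer is represented only by its size, and `fake_abs_td` is a made-up stand-in for the |TD| values a network would compute.

```python
import numpy as np


def two_stage_sample(buffer_size, batch_size, m, priority_fn, rng):
    """Hypothetical helper sketching LaBER's sampling scheme (not part of cpprb).

    Returns the buffer indices of the final batch and their positions
    inside the large, uniformly sampled batch.
    """
    # Step 1: draw an m-times larger batch uniformly from the buffer
    large_idx = rng.integers(0, buffer_size, size=m * batch_size)

    # Step 2: compute surrogate priorities for the large batch only
    priorities = np.asarray(priority_fn(large_idx), dtype=np.float64)
    probs = priorities / priorities.sum()

    # Step 3: sub-sample the final batch proportionally to the priorities
    pos = rng.choice(m * batch_size, size=batch_size, p=probs)
    return large_idx[pos], pos


rng = np.random.default_rng(42)
# Stand-in for the |TD| values a network would produce for each stored transition
fake_abs_td = rng.exponential(size=1000)

final_idx, pos = two_stage_sample(buffer_size=1000, batch_size=32, m=4,
                                  priority_fn=lambda idx: fake_abs_td[idx],
                                  rng=rng)
print(final_idx.shape, pos.shape)  # (32,) (32,)
```

With cpprb, step 1 corresponds to `ReplayBuffer.sample(batch_size * m)`, step 2 to your own network, and step 3 to the LaBER classes, which additionally return importance weights.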

According to the authors, LaBER can also be combined with non-uniform sampling such as PER (they call this combination PER-LaBER); however, the combination does not improve performance much, despite its additional computational cost.

cpprb provides three helper classes: LaBERmean, LaBERlazy, and LaBERmax. Unless you have a specific reason, use LaBERmean, which performs best both theoretically and experimentally. These classes are constructed with the following parameters:

Parameter   Type   Description                                                    Default
batch_size  int    Desired final batch size (output size)                         (required)
m           int    Multiplication factor (input size is m * batch_size)           4
eps         float  Small positive constant to avoid zero priority (keyword-only)  1e-6

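After construction, an instance can be used as a functor. Call it with the keyword argument priorities and, optionally, any other environment values; it returns a dict containing the sampled "indexes", their "weights", and the environment values sub-sampled with the same indexes.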

laber = LaBERmean(32, 4)

sample = laber(priorities=[ ... ],  # 32 * 4 surrogate priorities
               # optional: any additional environment values can be passed and subsampled together
               obs=...,
               act=...)
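To make the functor interface concrete, here is a hypothetical re-implementation in NumPy. It is not cpprb's source code: the class name `ToyLaBERmean`, the `rng` argument, the length check, and the weight formula \(w_i = \bar{p} / p_i\) (mean priority over the large batch divided by the sampled priority) are assumptions chosen to mirror the "mean" normalization.

```python
import numpy as np


class ToyLaBERmean:
    """Hypothetical sketch of a LaBER-mean style functor (not cpprb's implementation)."""

    def __init__(self, batch_size, m=4, *, eps=1e-6, rng=None):
        self.batch_size = batch_size
        self.m = m
        self.eps = eps
        self.rng = np.random.default_rng() if rng is None else rng

    def __call__(self, *, priorities, **kwargs):
        p = np.asarray(priorities, dtype=np.float64) + self.eps
        if p.shape[0] != self.m * self.batch_size:
            raise ValueError("expected m * batch_size priorities")
        probs = p / p.sum()

        # Sub-sample proportionally to the (eps-shifted) surrogate priorities
        idx = self.rng.choice(p.shape[0], size=self.batch_size, p=probs)

        # Importance weights: mean probability over the large batch / sampled probability.
        # Their expectation under `probs` is exactly 1.
        weights = probs.mean() / probs[idx]

        out = {"indexes": idx, "weights": weights}
        # Extra environment values (obs, act, ...) are sub-sampled with the same indexes
        out.update({k: np.asarray(v)[idx] for k, v in kwargs.items()})
        return out


laber = ToyLaBERmean(32, 4, rng=np.random.default_rng(0))
prio = np.random.default_rng(1).uniform(0.1, 2.0, size=128)
obs = np.arange(128 * 3).reshape(128, 3)

sample = laber(priorities=prio, obs=obs)
print(sorted(sample.keys()))  # ['indexes', 'obs', 'weights']
print(sample["obs"].shape)    # (32, 3)
```

Note that "indexes" point into the large batch of size m * batch_size, not into the replay buffer itself.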

2 Example Usage

The following pseudocode shows the basic usage.

from cpprb import ReplayBuffer, LaBERmean

buffer_size = int(1e+6)
env_dict = # Define environment values

batch_size = 32
m = 4

n_iteration = int(1e+6)

rb = ReplayBuffer(buffer_size, env_dict)

laber = LaBERmean(batch_size, m)

env = # Create Env
policy = # Create Policy Network

observation = env.reset()
for _ in range(n_iteration):
    action = policy(observation)
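    next_observation, reward, done, _ = env.step(action)

    # Store the transition (keys must match env_dict)
    rb.add(obs=observation, act=action, rew=reward,
           next_obs=next_observation, done=done)
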
    sample = rb.sample(batch_size * m)

    absTD = # Calculate surrogate priority using network

    idx_weights = laber(priorities=absTD)

    indexes = idx_weights["indexes"]
    weights = idx_weights["weights"]

    policy.train((absTD[indexes] * weights).mean())

    if done:
        observation = env.reset()
        rb.on_episode_end()
    else:
        observation = next_observation
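The weights are what keep the sub-sampled loss unbiased. The Monte Carlo check below is a sketch under assumptions: the |TD| values are made up, the per-sample loss is \(\tfrac{1}{2}|\delta|^2\), and the weights use the mean normalization \(w_i = \bar{p} / p_i\), computed inline rather than with cpprb. The weighted sub-batch loss should match the uniform large-batch loss on average, while the unweighted one is biased towards samples with large |TD|.

```python
import numpy as np

rng = np.random.default_rng(0)

abs_td = rng.exponential(size=128)  # hypothetical |TD| of a large batch (m * batch_size = 4 * 32)
losses = 0.5 * abs_td ** 2          # per-sample squared-error losses
target = losses.mean()              # loss of the uniformly sampled large batch

probs = abs_td / abs_td.sum()       # priority-proportional sub-sampling distribution

weighted_est, unweighted_est = [], []
for _ in range(5000):
    idx = rng.choice(abs_td.size, size=32, p=probs)
    weights = probs.mean() / probs[idx]  # assumed mean-normalized importance weights
    weighted_est.append(np.mean(losses[idx] * weights))
    unweighted_est.append(np.mean(losses[idx]))

weighted = np.mean(weighted_est)
unweighted = np.mean(unweighted_est)
print(f"uniform: {target:.3f}  weighted: {weighted:.3f}  unweighted: {unweighted:.3f}")
```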

The full example code is as follows:

# Example for Large Batch Experience Replay (LaBER)
# Ref:

import os
import datetime

import numpy as np

import gym

import tensorflow as tf
from tensorflow.keras.models import Sequential, clone_model
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.summary import create_file_writer

from cpprb import ReplayBuffer, LaBERmean

gamma = 0.99
batch_size = 64

N_iteration = int(1e+6)
target_update_freq = 10000
eval_freq = 1000

egreedy = 1.0
decay_egreedy = lambda e: max(e*0.99, 0.1)

# Use 4 times larger batch for initial uniform sampling
# Use LaBER-mean, which is the best variant
m = 4
LaBER = LaBERmean(batch_size, m)

# Log
dir_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join("logs", dir_name)
writer = create_file_writer(logdir + "/metrics")
writer.set_as_default()  # tf.summary.scalar below writes to this writer

# Env
env = gym.make('CartPole-v1')
eval_env = gym.make('CartPole-v1')

# For CartPole: input 4, output 2
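model = Sequential([Dense(64,activation='relu',
                          input_shape=(env.observation_space.shape)),
                    Dense(env.action_space.n)])  # no activation at the last layer
target_model = clone_model(model)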

# Loss Function

def Huber_loss(absTD):
    # Huber loss with threshold 1: quadratic for |TD| <= 1, linear beyond
    return tf.where(absTD > 1.0, absTD - 0.5, 0.5 * tf.math.square(absTD))

def MSE(absTD):
    return tf.math.square(absTD)

loss_func = Huber_loss

optimizer = Adam()

buffer_size = int(1e+6)
env_dict = {"obs":{"shape": env.observation_space.shape},
            "act":{"shape": 1,"dtype": np.ubyte},
            "rew": {},
            "next_obs": {"shape": env.observation_space.shape},
            "done": {}}

# Nstep
nstep = 3
# nstep = False

if nstep:
    Nstep = {"size": nstep, "rew": "rew", "next": "next_obs"}
    discount = tf.constant(gamma ** nstep)
else:
    Nstep = None
    discount = tf.constant(gamma)

rb = ReplayBuffer(buffer_size,env_dict,Nstep=Nstep)

def Q_func(model,obs,act,act_shape):
    return tf.reduce_sum(model(obs) * tf.one_hot(act,depth=act_shape), axis=1)

def DQN_target_func(model,target,next_obs,rew,done,gamma,act_shape):
    return gamma*tf.reduce_max(target(next_obs),axis=1)*(1.0-done) + rew

def Double_DQN_target_func(model,target,next_obs,rew,done,gamma,act_shape):
    # Double DQN: select the next action with the online model, evaluate it with the target model
    act = tf.math.argmax(model(next_obs),axis=1)
    return gamma*tf.reduce_sum(target(next_obs)*tf.one_hot(act,depth=act_shape), axis=1)*(1.0-done) + rew

target_func = Double_DQN_target_func

def evaluate(model,env):
    obs = env.reset()
    total_rew = 0

    while True:
        Q = tf.squeeze(model(obs.reshape(1,-1)))
        act = np.argmax(Q)
        obs, rew, done, _ = env.step(act)
        total_rew += rew

        if done:
            return total_rew

# Start Experiment

observation = env.reset()

# Warming up
for n_step in range(100):
    action = env.action_space.sample() # Random Action
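    next_observation, reward, done, info = env.step(action)
    rb.add(obs=observation,
           act=action,
           rew=reward,
           next_obs=next_observation,
           done=done)
    observation = next_observation
    if done:
        observation = env.reset()
        rb.on_episode_end()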

n_episode = 0
observation = env.reset()
for n_step in range(N_iteration):

    if np.random.rand() < egreedy:
        action = env.action_space.sample()
    else:
        Q = tf.squeeze(model(observation.reshape(1,-1)))
        action = np.argmax(Q)

    egreedy = decay_egreedy(egreedy)

    next_observation, reward, done, info = env.step(action)
    rb.add(obs=observation,
           act=action,
           rew=reward,
           next_obs=next_observation,
           done=done)
    observation = next_observation

    # Uniform sampling
    sample = rb.sample(batch_size * m)

    with tf.GradientTape() as tape:
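        Q =  Q_func(model,
                    tf.constant(sample["obs"]),
                    tf.constant(sample["act"].ravel()),
                    tf.constant(env.action_space.n))
        target_Q = tf.stop_gradient(target_func(model,target_model,
                                                tf.constant(sample["next_obs"]),
                                                tf.constant(sample["rew"].ravel()),
                                                tf.constant(sample["done"].ravel()),
                                                discount,
                                                tf.constant(env.action_space.n)))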
        tf.summary.scalar("Target Q", data=tf.reduce_mean(target_Q), step=n_step)
        absTD = tf.math.abs(target_Q - Q)

        # Sub-sample according to surrogate priorities
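        #   When the loss is L2 and there is no activation at the last layer,
        #   |TD| is (proportional to) the surrogate priority. For the Huber loss
        #   the exact surrogate would be min(|TD|, 1); |TD| is used as an approximation.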
        sample = LaBER(priorities=absTD)
        indexes = tf.constant(sample["indexes"])
        weights = tf.constant(sample["weights"])

        absTD = tf.gather(absTD, indexes)
        assert absTD.shape == weights.shape, f"BUG: absTD.shape: {absTD.shape}, weights.shape {weights.shape}"

        loss = tf.reduce_mean(loss_func(absTD)*weights)

    grad = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grad, model.trainable_weights))
    tf.summary.scalar("Loss vs training step", data=loss, step=n_step)

    if done:
        observation = env.reset()
        rb.on_episode_end()
        n_episode += 1

    if n_step % target_update_freq == 0:
        target_model.set_weights(model.get_weights())

    if n_step % eval_freq == eval_freq-1:
        eval_rew = evaluate(model,eval_env)
        tf.summary.scalar("episode reward vs training step",data=eval_rew,step=n_step)

3 Notes

We add eps to avoid zero priorities, although the original implementation does not have it. If you don't want to add this small positive constant, pass eps=0 to the constructor (__init__).
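A small illustration of why eps matters, assuming priorities are normalized as \(p_i / \sum_j p_j\) before sampling (the helper `sampling_probs` is hypothetical, not cpprb's code). If every surrogate priority in the large batch is zero, for example because the network fits those targets exactly, the normalization divides zero by zero; a small eps turns this case into uniform sampling. With eps=0, any single zero-priority transition can never be selected.

```python
import numpy as np


def sampling_probs(priorities, eps):
    """Normalize priorities into a sampling distribution (hypothetical helper)."""
    p = np.asarray(priorities, dtype=np.float64) + eps
    with np.errstate(invalid="ignore", divide="ignore"):
        return p / p.sum()


zeros = np.zeros(4)
print(sampling_probs(zeros, eps=0.0))   # [nan nan nan nan] -> unusable for sampling
print(sampling_probs(zeros, eps=1e-6))  # [0.25 0.25 0.25 0.25] -> uniform

# With eps=0, the first transition gets probability exactly 0 and is never sampled
print(sampling_probs([0.0, 1.0, 1.0, 2.0], eps=0.0))
```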

4 Technical Detail

Since computing the surrogate priority usually requires a forward pass of the network, we implement LaBER separately from the replay buffer.

As a result, LaBERmean and the other variants are simple classes, so they are implemented as ordinary Python classes instead of Cython cdef classes.

  1. T. Lahire et al., "Large Batch Experience Replay", CoRR (2021) (arXiv:2110.01528, code)